from torchvision.datasets import ImageFolder
from .transforms import get_train_transforms, get_test_transforms

class FontDataset(ImageFolder):
    """``ImageFolder`` dataset whose preprocessing pipeline depends on the split.

    The only customization over ``torchvision.datasets.ImageFolder`` is the
    choice of ``transform``: the training pipeline or the test pipeline,
    both supplied by the sibling ``transforms`` module.
    """

    def __init__(self, root, is_train=True):
        """Create the dataset rooted at *root*.

        Args:
            root: Dataset root, forwarded unchanged to ``ImageFolder``.
            is_train: When truthy, images are processed with
                ``get_train_transforms()``; otherwise with
                ``get_test_transforms()``. Exactly one factory is called.
        """
        if is_train:
            pipeline = get_train_transforms()
        else:
            pipeline = get_test_transforms()
        super().__init__(root=root, transform=pipeline)